#!/bin/bash

# author: sunning
# created: 2022年11月24日9:36:18
# updated: 2022年12月3日18:50:44
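# desc: Code-mode launch script for AlphaFold2 (OpenFold) distributed training jobs.
#       The datasets used are at the OSS path that is mounted by default when the pod starts.
# The script expects 3 parameters. They are read as already-set (environment)
# variables; nothing is parsed from the command line:
# output_dir  directory where results are written
# gpus        number of GPUs used by the job on each node
# num_nodes   number of training nodes used by the job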

#export NCCL_IB_DISABLE=0
#export NCCL_IB_HCA=mlx5_5
#export NCCL_IB_GID_INDEX=3
#export NCCL_IB_TC=106

# Default (bundled example) inputs for the training job. Every other path is
# derived from input_dir. Because input_dir already ends in '/', the joined
# paths contain a double slash, which the filesystem treats as a single one.
input_dir='/dros/common/ecosystem/MEDICINE_COMPUTING/apps/AlphaFold/EXAMPLE/AlphaFold_CT/input_dir/'

# Directories.
train_data_dir="${input_dir}/train_data_dir"
train_alignment_dir="${input_dir}/train_alignment_dir"
template_mmcif_dir="${input_dir}/template_mmcif_dir"

# JSON cache files.
chain_data_cache="${input_dir}/chain_data_cache.json"
mmcif_cache="${input_dir}/mmcif_cache.json"

if [ -z $output_dir ] ; then
  echo "please input valid output_dir!"
  exit 1
fi

# Fall back to the shared public datasets when the bundled inputs are missing.
# [[ ]] does not word-split: with the old unquoted `[ ! -d $dir ]`, an empty
# value became `[ ! -d ]`, which is false, so the fallback was silently
# skipped (and paths with spaces broke the test entirely).

# Training structures and their chain cache must be used together.
if [[ ! -d "$train_data_dir" || ! -f "$chain_data_cache" ]]; then
  train_data_dir='/dros/common/public/dataset/pdb_mmcif/mmcif_files/'
  # NOTE(review): relative path, resolved against the current working directory.
  chain_data_cache='chain_data_cache.json'
fi

if [[ ! -d "$train_alignment_dir" ]]; then
  train_alignment_dir='/dros/common/temp/self-define/yangfei/jyf/preprocess_data/pdb/'
fi

# Template structures and their release-date cache must be used together.
if [[ ! -d "$template_mmcif_dir" || ! -f "$mmcif_cache" ]]; then
  template_mmcif_dir='/dros/common/public/dataset/pdb_mmcif/mmcif_files/'
  # NOTE(review): relative path, resolved against the current working directory.
  mmcif_cache='mmcif_cache.json'
fi

# Build the deepspeed hostfile from the Volcano-provided host lists:
#   1. wait until every master and worker hostname resolves,
#   2. ssh to each host and record its `hostname -I` output in
#      /etc/mpi/hostfile (masters first, then workers),
#   3. let new_hostfile.py rewrite it into /etc/mpi/hostfile_v2 using $gpus.

# Split each host file on whitespace, as the former `cat` expansion did, but
# without glob expansion. `read -d ''` consumes the whole file and returns
# non-zero at EOF, hence `|| true`.
read -r -d '' -a master_hosts < /etc/volcano/master.host || true
read -r -d '' -a worker_hosts < /etc/volcano/worker.host || true

if (( ${#master_hosts[@]} + ${#worker_hosts[@]} == 0 )); then
  echo "no hosts found in /etc/volcano/master.host or /etc/volcano/worker.host" >&2
  exit 1
fi

# NOTE(review): this waits forever for a host that never resolves.
for host in "${master_hosts[@]}" "${worker_hosts[@]}"; do
  until nslookup "$host"; do
    echo waiting for myservice
    sleep 2
  done
done

# Redirect the whole loop once. The previous `>` inside the master loop
# truncated the file on every iteration, so only the last master survived.
# `ssh -n` stops ssh from consuming this script's stdin. A host whose IP
# cannot be fetched aborts the job instead of silently shrinking the cluster.
{
  for host in "${master_hosts[@]}" "${worker_hosts[@]}"; do
    if ! ssh -n "$host" hostname -I; then
      echo "failed to query IP address of $host" >&2
      exit 1
    fi
  done
} > /etc/mpi/hostfile

# Do not launch training against a stale/missing hostfile_v2.
if ! python new_hostfile.py /etc/mpi/hostfile /etc/mpi/hostfile_v2 "$gpus"; then
  echo "failed to generate /etc/mpi/hostfile_v2" >&2
  exit 1
fi

# Launch OpenFold training with deepspeed on every host in /etc/mpi/hostfile_v2,
# with NCCL debug logging enabled for this command only.
# All expansions are quoted. An empty value is then passed as one explicit
# (empty) argument instead of vanishing and shifting the remaining arguments.
# A value containing spaces stays a single argument.
# NOTE(review): 2021-10-10 is a positional date argument of train_openfold.py,
# presumably the template release-date cutoff; confirm against its argparse.
NCCL_DEBUG=INFO deepspeed --hostfile=/etc/mpi/hostfile_v2 \
    train_openfold.py "$train_data_dir" "$train_alignment_dir" "$template_mmcif_dir" "$output_dir" 2021-10-10 \
    --template_release_dates_cache_path "$mmcif_cache" \
    --precision 32 \
    --gpus "$gpus" \
    --num_nodes "$num_nodes" \
    --replace_sampler_ddp=True \
    --seed 42 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --resume_model_weights_only True \
    --train_chain_data_cache_path "$chain_data_cache" \
    --obsolete_pdbs_file_path /dros/common/public/dataset/pdb_mmcif/obsolete.dat
